Validate B/scales/zero_points shape in MatMulNBits::PrePack#29445
apsonawane wants to merge 4 commits into
MatMulNBits::PrePack ran at session initialization and called the MLAS
pack routines using byte counts derived from the node attributes
(N, K, bits, block_size) without ever comparing those attributes to
the actual tensor Shape(). A crafted .onnx whose attributes overstate
the real B (or scales / zero_points) extent triggered a
heap-buffer-overflow READ inside MlasQNBitGemmPackQuantBData /
MlasLutGemmPack during OrtApis::CreateSession (no Run() required).
The canonical shape check already lives in
matmul_nbits_helper::CheckInputs, but is invoked only from Compute()
-- after PrePack has already done the OOB read, and by then the
original B tensor is replaced with nullptr in the kernel context so
the Compute-time check never re-validates it.
Fix: at the top of PrePack, after the existing early-return guards
and before any tensor.DataRaw() read, validate the incoming
initializer's Shape() against the attribute-derived shape:
- B -> (N, k_blocks, blob_size)
- scales -> (N * k_blocks) or (N, k_blocks)
- zero_points -> uint8: (N * zp_blob) or (N, zp_blob); else
(N * k_blocks) or (N, k_blocks)
A mismatch returns INVALID_ARGUMENT so the session fails to load
rather than reading past the buffer.
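The shape rules listed above can be sketched as a standalone check. This is a minimal illustration, not the code in the PR: `NBitsAttrs`, `Shape`, `ValidateShapes`, and the short error strings are hypothetical stand-ins for the ORT types, and the formulas for `k_blocks` and `blob_size` are assumptions (ceiling divisions), since the description does not spell them out. Only `zp_blob = (k_blocks * bits + 7) / 8` is taken from the review text.

```cpp
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical stand-in for the node attributes; not the ORT kernel members.
struct NBitsAttrs {
  int64_t N;
  int64_t K;
  int64_t bits;
  int64_t block_size;
};

using Shape = std::vector<int64_t>;

inline int64_t CeilDiv(int64_t a, int64_t b) { return (a + b - 1) / b; }

// Assumed derivations (not quoted in the PR text): blocks along K, bytes per block.
inline int64_t KBlocks(const NBitsAttrs& a) { return CeilDiv(a.K, a.block_size); }
inline int64_t BlobSize(const NBitsAttrs& a) { return CeilDiv(a.block_size * a.bits, 8); }
// zp_blob = (k_blocks * bits + 7) / 8, as quoted in the review.
inline int64_t ZpBlobSizeUint8(const NBitsAttrs& a) { return CeilDiv(KBlocks(a) * a.bits, 8); }

// Accepts either the flattened 1D layout (n * inner) or the 2D layout (n, inner).
inline bool IsFlatOr2D(const Shape& s, int64_t n, int64_t inner) {
  return s == Shape{n * inner} || s == Shape{n, inner};
}

// Returns an empty string when all shapes match, otherwise a short error message.
// zero_points may be nullptr when the optional input is absent.
inline std::string ValidateShapes(const NBitsAttrs& a, const Shape& b, const Shape& scales,
                                  const Shape* zero_points, bool zp_is_uint8) {
  const int64_t k_blocks = KBlocks(a);
  if (b != Shape{a.N, k_blocks, BlobSize(a)}) return "B shape mismatch";
  if (!IsFlatOr2D(scales, a.N, k_blocks)) return "scales shape mismatch";
  if (zero_points != nullptr) {
    // Packed uint8 zero points use zp_blob per row; other types use one value per block.
    const int64_t inner = zp_is_uint8 ? ZpBlobSizeUint8(a) : k_blocks;
    if (!IsFlatOr2D(*zero_points, a.N, inner)) return "zero_points shape mismatch";
  }
  return "";
}
```

Under these assumptions, N=4, K=64, bits=4, block_size=32 gives k_blocks=2, blob_size=16, and zp_blob=1, so a B of shape (4, 1, 16) would be rejected before any data pointer is read.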
Pull request overview
This PR hardens the CPU MatMulNBits contrib op against malformed models by adding early shape validation in MatMulNBits<T1>::PrePack() so that session initialization rejects inconsistent initializers before any MLAS packing routine can read past the provided buffers.
Changes:
- Add attribute-derived initializer shape checks for B, scales, and zero_points at the top of MatMulNBits<T1>::PrePack().
- Add new unit tests that expect session creation to fail (pre-Compute()) for mismatched initializer shapes, plus a compatibility test for legacy flattened scales/zero_points layouts.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc | Adds new PrePack-time shape validation intended to prevent OOB reads during weight packing at session init. |
| onnxruntime/test/contrib_ops/matmul_4bits_test.cc | Adds tests that exercise PrePack-time rejection for malformed initializer shapes and verifies legacy flattened layouts remain accepted. |
tianleiwu left a comment
Review: Validate B/scales/zero_points shape in MatMulNBits::PrePack
Verdict: LGTM (COMMENT). This correctly closes a real heap-buffer-overflow READ reachable at CreateSession time (no Run() required). The analysis holds up:
- The validation block is placed at the very top of PrePack, after is_packed = false and before every tensor.DataRaw() / constant-tensor read on all paths (LUT and non-LUT, x64 and ARM64). This is the key correctness property and it is satisfied.
- Running it ahead of the has_g_idx_ / unquantized-ZP / !MlasIsQNBitGemmAvailable early-returns is the right call: it makes bad-shape models fail consistently even on configs (e.g. Win x86 32-bit) where PrePack would otherwise short-circuit and the B tensor is dropped before Compute()'s CheckInputs runs.
- Validating the constant scales/zero_points during the B prepack (via TryGetConstantInput) is necessary and correct: the B pack path dereferences those tensors before their own PrePack calls run, so per-tensor validation alone would be too late. The gating (has_zp_arg_ && has_zp_input_) matches the conditions under which the pack routines actually read ZP, so no over- or under-validation.
- The derived shapes match matmul_nbits_helper::CheckInputs exactly: B (N, k_blocks, blob_size), scales [N*k_blocks] / [N,k_blocks], uint8 ZP [N*zp_blob] / [N,zp_blob] with zp_blob=(k_blocks*bits+7)/8, else [N*k_blocks] / [N,k_blocks]. The INVALID_ARGUMENT return makes session load fail cleanly.
- Test coverage is good: bad B extent, wrong B rank, bad scales, bad uint8 ZP (all expect "MatMulNBits PrePack:" failure before Compute()), plus a positive test that legacy flattened 1D scales/ZP layouts still load and compute correctly (no backward-compat regression).
Two minor, non-blocking observations below.
Nitpick (completeness, not a security gap): For non-uint8 (float) zero_points, CheckInputs additionally rejects a zero_points whose element type differs from scales; this guard checks only the shape. It is not an OOB risk (the LUT float-ZP path dispatches on the ZP's own dtype and enforces its size), and Compute()'s CheckInputs still catches the dtype mismatch — so this is purely a note, no change required.
// the pack routines below dereference tensor.DataRaw(). The MLAS pack routines size their reads
// from the (N, K, bits, block_size) attributes; without this check a crafted model whose
// attributes overstate the real tensor extents would trigger a heap-buffer-overflow READ at
// session initialization. The matching guard in matmul_nbits_helper::CheckInputs is invoked
Suggestion (maintainability): the shape math here (k_blocks, blob_size, zp_blob_size_uint8) and the accepted layouts duplicate matmul_nbits_helper::CheckInputs. The cross-reference comment helps, but the two can silently drift if the canonical layout ever changes (e.g. a new packing scheme). Since these are constexpr-style derivations, consider factoring the layout math into a small shared helper in matmul_nbits_helper.h that both this guard and CheckInputs call, so a future layout change updates one place. (Reusing CheckInputs directly would change the error strings the new tests assert on, so a shared derivation helper is the lower-friction option.) Non-blocking.
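One possible shape for such a shared helper, as a rough sketch only: `NBitsLayout`, `CeilDivI64`, and `DeriveNBitsLayout` are hypothetical names that do not exist in matmul_nbits_helper.h, and the `k_blocks` and `blob_size` formulas are assumed ceiling divisions rather than copied from CheckInputs. Only the `zp_blob_size_uint8` formula is the one quoted in this review.

```cpp
#include <cstdint>

// Hypothetical shared layout record; both the PrePack guard and CheckInputs
// would read their expected extents from here instead of re-deriving them.
struct NBitsLayout {
  int64_t k_blocks;            // number of quantization blocks along K (assumed ceil(K / block_size))
  int64_t blob_size;           // bytes per packed block of B (assumed ceil(block_size * bits / 8))
  int64_t zp_blob_size_uint8;  // bytes of packed uint8 zero points per row: (k_blocks * bits + 7) / 8
};

constexpr int64_t CeilDivI64(int64_t a, int64_t b) { return (a + b - 1) / b; }

// Single place where the layout math lives; constexpr so it can be checked at compile time.
constexpr NBitsLayout DeriveNBitsLayout(int64_t K, int64_t bits, int64_t block_size) {
  const int64_t k_blocks = CeilDivI64(K, block_size);
  return NBitsLayout{k_blocks, CeilDivI64(block_size * bits, 8), CeilDivI64(k_blocks * bits, 8)};
}

// Compile-time sanity checks for one 4-bit configuration.
static_assert(DeriveNBitsLayout(64, 4, 32).k_blocks == 2, "two blocks of 32 cover K=64");
static_assert(DeriveNBitsLayout(64, 4, 32).blob_size == 16, "32 values * 4 bits = 16 bytes");
```

Each caller would keep its own error strings and only share the derived numbers, which is why this avoids changing the messages the new tests assert on.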
"MatMulNBits PrePack: zero_points initializer shape ", s,
" does not match attribute-derived shape [", n * zp_blob_size_uint8, "] or [",
n, ",", zp_blob_size_uint8, "]");
} else {
Nitpick: this non-uint8 branch validates the ZP shape but, unlike CheckInputs, does not verify zero_points and scales share the same element type. Not a security gap (no OOB stems from the dtype alone, and Compute()'s CheckInputs still rejects it), so this is just a completeness note — no change required.